{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Week 9: Normalising flows pt 1 - bijectors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import tensorflow_probability as tfp\n",
    "tfd = tfp.distributions\n",
    "tfb = tfp.bijectors\n",
    "\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true
   },
   "source": [
    "## TensorFlow bijectors"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Base distribution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "base_dist = tfd.MultivariateNormalDiag(loc=tf.zeros([2], tf.float32), scale_diag=tf.constant([1, 1], tf.float32))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "SAMPLE_BATCH_SIZE = 512"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z = base_dist.sample(SAMPLE_BATCH_SIZE)\n",
    "print(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess = tf.InteractiveSession()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "z_samples = z.eval()\n",
    "print(type(z_samples))\n",
    "print(z_samples.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(5, 5))\n",
    "plt.scatter(z_samples[:, 0], z_samples[:, 1], s=10)\n",
    "plt.title(\"Base distribution: standard normal\")\n",
    "plt.xlim([-4, 4])\n",
    "plt.ylim([-4, 4])\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Transform the distribution"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A `Bijector` represents an invertible, differentiable transformation of a random variable. Bijectors are the building blocks of a normalising flow.\n",
    "They are characterised by three main methods:\n",
    "\n",
    "1. `forward`\n",
    "2. `inverse`\n",
    "3. `forward_log_det_jacobian` / `inverse_log_det_jacobian`\n",
    "\n",
    "By convention, the `forward` operation acts on samples from the base distribution (to generate samples from the transformed distribution), while the `inverse` operation, together with `inverse_log_det_jacobian`, is used to calculate probabilities.\n",
    "\n",
    "For example, the `Affine` bijector:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "affine_bijector = tfb.Affine(shift=[1., -1.], scale_diag=[0.5, 1.5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "fwd_z = affine_bijector.forward(z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "z_samples, x_samples = sess.run([z, fwd_z])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "fig = plt.figure(figsize=(12, 5))\n",
    "ax = fig.add_subplot(121)\n",
    "ax2 = fig.add_subplot(122)\n",
    "\n",
    "ax.scatter(z_samples[:, 0], z_samples[:, 1], s=10)\n",
    "ax.set_title(\"Base distribution: standard normal\")\n",
    "ax.set_xlim([-5, 5])\n",
    "ax.set_ylim([-5, 5])\n",
    "\n",
    "ax2.scatter(x_samples[:, 0], x_samples[:, 1], s=10, color='r')\n",
    "ax2.set_title(\"Transformed distribution: shift [1, -1], scale [0.5, 1.5]\")\n",
    "ax2.set_xlim([-5, 5])\n",
    "ax2.set_ylim([-5, 5])\n",
    "plt.show()"
   ]
  },
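  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The plot above can be compared with what the affine map does analytically. As a sketch, write $z$ for a base sample, $s = (0.5, 1.5)$ for `scale_diag` and $b = (1, -1)$ for `shift` (the symbols $f$, $s$ and $b$ are notation introduced here for illustration). The forward transformation and the distribution it induces are then:\n",
    "\n",
    "$$x = f(z) = s \\odot z + b, \\qquad z \\sim \\mathcal{N}(0, I) \\;\\Rightarrow\\; x \\sim \\mathcal{N}\\left(b, \\operatorname{diag}(s_1^2, s_2^2)\\right)$$\n",
    "\n",
    "So the transformed samples should be centred at $(1, -1)$, with standard deviation $0.5$ along the first axis and $1.5$ along the second."
   ]
  },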
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "fwd_inv_z = affine_bijector.inverse(fwd_z)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "latents = np.random.random((SAMPLE_BATCH_SIZE, 2))\n",
    "print(np.allclose(latents, sess.run(fwd_inv_z, feed_dict={z: latents})))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Computing probabilities"
   ]
  },
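  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The cells in this section follow the change-of-variables formula. As a sketch, assume $x = f(z)$ with $f$ the bijector's `forward` map, $p_Z$ the base density and $p_X$ the transformed density (these symbols are notation introduced here, not names from the code):\n",
    "\n",
    "$$\\log p_X(x) = \\log p_Z\\left(f^{-1}(x)\\right) + \\log\\left|\\det \\frac{\\partial f^{-1}(x)}{\\partial x}\\right|$$\n",
    "\n",
    "The first term corresponds to `base_dist.log_prob(affine_bijector.inverse(x))` and the second to `affine_bijector.inverse_log_det_jacobian(x, event_ndims=1)`."
   ]
  },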
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x = tf.placeholder(shape=(1, 2), dtype=tf.float32)\n",
    "\n",
    "log_det_dzdx = affine_bijector.inverse_log_det_jacobian(x, event_ndims=1)\n",
    "log_det_dzdx"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "inv_x = affine_bijector.inverse(x)\n",
    "inv_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "log_prob_inv_x = base_dist.log_prob(inv_x)\n",
    "log_prob_inv_x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "x_fixed_sample = np.array([[1., -1.]])  # Mode of the transformed distribution\n",
    "\n",
    "sess.run(log_det_dzdx, feed_dict={x: x_fixed_sample})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check: Jacobian determinant is just the product of scaling factors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "- np.log(0.5) - np.log(1.5)"
   ]
  },
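  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To spell out the check above as a worked equation: assuming the inverse map is $z_i = (x_i - b_i) / s_i$ with scaling factors $s = (0.5, 1.5)$ and shift $b$ (symbols introduced here for illustration), its Jacobian is diagonal and does not depend on $x$:\n",
    "\n",
    "$$\\frac{\\partial z}{\\partial x} = \\operatorname{diag}\\left(\\frac{1}{s_1}, \\frac{1}{s_2}\\right), \\qquad \\log\\left|\\det \\frac{\\partial z}{\\partial x}\\right| = -\\log s_1 - \\log s_2 = -\\log 0.5 - \\log 1.5 = \\log\\frac{4}{3} \\approx 0.2877$$"
   ]
  },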
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Calculate log probability of `x`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "sess.run(log_prob_inv_x + log_det_dzdx, feed_dict={x: np.array([[1., -1.]])})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Check:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "np.log(np.sqrt(1 / (2 * np.pi)**2)) - np.log(0.5) - np.log(1.5)"
   ]
  },
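  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The closed-form check can be written out as follows. This is a sketch that assumes $p_Z$ denotes the 2D standard normal base density and $p_X$ the transformed density (notation introduced here). The point $x = (1, -1)$ equals the shift, so it maps back to $z = (0, 0)$, where $p_Z(0) = 1 / (2\\pi)$:\n",
    "\n",
    "$$\\log p_X(x) = \\log\\frac{1}{2\\pi} - \\log 0.5 - \\log 1.5 \\approx -1.8379 + 0.2877 = -1.5502$$"
   ]
  },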
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "sess.close()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
